{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "14f6ca8e",
   "metadata": {},
   "source": [
    "# Benchmarking Classifiers\n",
    "\n",
    "The `benchmarking` module in `sktime` makes it easy to benchmark `sktime` and `sktime`-compatible classifiers. A benchmark runs every combination of time series classification models and tasks, where each task combines a dataset, a cross-validation splitting strategy, and one or more scorers.\n",
    "\n",
    "This notebook demonstrates a classifier benchmark run, from registering estimators and tasks to collecting the results."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71de931d",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "56808fba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, brier_score_loss\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "from sktime.benchmarking.classification import ClassificationBenchmark\n",
    "from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier\n",
    "from sktime.classification.dummy import DummyClassifier\n",
    "from sktime.datasets import load_unit_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bc813ab",
   "metadata": {},
   "source": [
    "### Instantiate `ClassificationBenchmark` class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "d5d1b325",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark = ClassificationBenchmark()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74c1ea99",
   "metadata": {},
   "source": [
    "### Add the classifiers that need to be benchmarked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "e893c850",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark.add_estimator(\n",
    "    estimator=DummyClassifier(),\n",
    "    estimator_id=\"DummyClassifier\",\n",
    ")\n",
    "benchmark.add_estimator(\n",
    "    estimator=KNeighborsTimeSeriesClassifier(),\n",
    "    estimator_id=\"KNeighborsTimeSeriesClassifier\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "219337ea",
   "metadata": {},
   "source": [
    "### Specify cross-validation splitting strategy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "3c48158c",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_splits = 3\n",
    "cv = KFold(n_splits=n_splits)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "743b85fe",
   "metadata": {},
   "source": [
    "### Specify performance metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "9634f5b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "scorers = [accuracy_score, brier_score_loss]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d41f1238",
   "metadata": {},
   "source": [
    "### Specify dataset loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "c960deda",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_loaders = [load_unit_test]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5161e0f",
   "metadata": {},
   "source": [
    "### Add tasks to the `ClassificationBenchmark` instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "36bbebc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset_loader in dataset_loaders:\n",
    "    benchmark.add_task(\n",
    "        dataset_loader,\n",
    "        cv,\n",
    "        scorers,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdf6e37b",
   "metadata": {},
   "source": [
    "### Run all classifier-task combinations and save the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "e0406c2c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>validation_id</th>\n",
       "      <td>[dataset=load_unit_test]_[cv_splitter=KFold]</td>\n",
       "      <td>[dataset=load_unit_test]_[cv_splitter=KFold]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model_id</th>\n",
       "      <td>DummyClassifier</td>\n",
       "      <td>KNeighborsTimeSeriesClassifier</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>accuracy_score_fold_0_test</th>\n",
       "      <td>0.285714</td>\n",
       "      <td>0.928571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>accuracy_score_fold_1_test</th>\n",
       "      <td>0.571429</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>accuracy_score_fold_2_test</th>\n",
       "      <td>0.285714</td>\n",
       "      <td>0.857143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>accuracy_score_mean</th>\n",
       "      <td>0.380952</td>\n",
       "      <td>0.928571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>accuracy_score_std</th>\n",
       "      <td>0.164957</td>\n",
       "      <td>0.071429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>brier_score_loss_fold_0_test</th>\n",
       "      <td>0.326531</td>\n",
       "      <td>0.357143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>brier_score_loss_fold_1_test</th>\n",
       "      <td>0.25</td>\n",
       "      <td>0.428571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>brier_score_loss_fold_2_test</th>\n",
       "      <td>0.127551</td>\n",
       "      <td>0.571429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>brier_score_loss_mean</th>\n",
       "      <td>0.234694</td>\n",
       "      <td>0.452381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>brier_score_loss_std</th>\n",
       "      <td>0.100369</td>\n",
       "      <td>0.109109</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fit_time_fold_0_test</th>\n",
       "      <td>0.009119</td>\n",
       "      <td>0.052076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fit_time_fold_1_test</th>\n",
       "      <td>0.008259</td>\n",
       "      <td>0.006292</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fit_time_fold_2_test</th>\n",
       "      <td>0.049507</td>\n",
       "      <td>0.00181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fit_time_mean</th>\n",
       "      <td>0.022295</td>\n",
       "      <td>0.020059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fit_time_std</th>\n",
       "      <td>0.02357</td>\n",
       "      <td>0.027818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pred_time_fold_0_test</th>\n",
       "      <td>0.002845</td>\n",
       "      <td>0.196373</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pred_time_fold_1_test</th>\n",
       "      <td>0.002953</td>\n",
       "      <td>0.202413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pred_time_fold_2_test</th>\n",
       "      <td>0.001597</td>\n",
       "      <td>0.195705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pred_time_mean</th>\n",
       "      <td>0.002465</td>\n",
       "      <td>0.198164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pred_time_std</th>\n",
       "      <td>0.000754</td>\n",
       "      <td>0.003695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>runtime_secs</th>\n",
       "      <td>0.024761</td>\n",
       "      <td>0.218223</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                         0  \\\n",
       "validation_id                 [dataset=load_unit_test]_[cv_splitter=KFold]   \n",
       "model_id                                                   DummyClassifier   \n",
       "accuracy_score_fold_0_test                                        0.285714   \n",
       "accuracy_score_fold_1_test                                        0.571429   \n",
       "accuracy_score_fold_2_test                                        0.285714   \n",
       "accuracy_score_mean                                               0.380952   \n",
       "accuracy_score_std                                                0.164957   \n",
       "brier_score_loss_fold_0_test                                      0.326531   \n",
       "brier_score_loss_fold_1_test                                          0.25   \n",
       "brier_score_loss_fold_2_test                                      0.127551   \n",
       "brier_score_loss_mean                                             0.234694   \n",
       "brier_score_loss_std                                              0.100369   \n",
       "fit_time_fold_0_test                                              0.009119   \n",
       "fit_time_fold_1_test                                              0.008259   \n",
       "fit_time_fold_2_test                                              0.049507   \n",
       "fit_time_mean                                                     0.022295   \n",
       "fit_time_std                                                       0.02357   \n",
       "pred_time_fold_0_test                                             0.002845   \n",
       "pred_time_fold_1_test                                             0.002953   \n",
       "pred_time_fold_2_test                                             0.001597   \n",
       "pred_time_mean                                                    0.002465   \n",
       "pred_time_std                                                     0.000754   \n",
       "runtime_secs                                                      0.024761   \n",
       "\n",
       "                                                                         1  \n",
       "validation_id                 [dataset=load_unit_test]_[cv_splitter=KFold]  \n",
       "model_id                                    KNeighborsTimeSeriesClassifier  \n",
       "accuracy_score_fold_0_test                                        0.928571  \n",
       "accuracy_score_fold_1_test                                             1.0  \n",
       "accuracy_score_fold_2_test                                        0.857143  \n",
       "accuracy_score_mean                                               0.928571  \n",
       "accuracy_score_std                                                0.071429  \n",
       "brier_score_loss_fold_0_test                                      0.357143  \n",
       "brier_score_loss_fold_1_test                                      0.428571  \n",
       "brier_score_loss_fold_2_test                                      0.571429  \n",
       "brier_score_loss_mean                                             0.452381  \n",
       "brier_score_loss_std                                              0.109109  \n",
       "fit_time_fold_0_test                                              0.052076  \n",
       "fit_time_fold_1_test                                              0.006292  \n",
       "fit_time_fold_2_test                                               0.00181  \n",
       "fit_time_mean                                                     0.020059  \n",
       "fit_time_std                                                      0.027818  \n",
       "pred_time_fold_0_test                                             0.196373  \n",
       "pred_time_fold_1_test                                             0.202413  \n",
       "pred_time_fold_2_test                                             0.195705  \n",
       "pred_time_mean                                                    0.198164  \n",
       "pred_time_std                                                     0.003695  \n",
       "runtime_secs                                                      0.218223  "
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df = benchmark.run(\"./classifier_benchmarking_results.csv\")\n",
    "results_df.T"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sktime-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
